import os, json
import pdb

import constants

current_dir = os.path.dirname(os.path.abspath(__file__))


def find_category_aspects(cate_att_tags, category):
    """Return the aspect names recorded for *category* as a list.

    Args:
        cate_att_tags: Mapping of category -> mapping keyed by aspect name.
        category: Key to look up in ``cate_att_tags``.

    Returns:
        A new list holding the keys of ``cate_att_tags[category]`` in
        iteration order.

    Raises:
        KeyError: If ``category`` is not present in ``cate_att_tags``.
    """
    return [*cate_att_tags[category].keys()]


def get_tag_group_map(group_tag_file):
    """Invert a ``{group: [tag, ...]}`` mapping into ``{tag: group}``.

    Args:
        group_tag_file: Mapping of group name -> iterable of tags belonging
            to that group.

    Returns:
        A dict mapping every tag to its group. When a tag is listed under
        several groups, the group encountered last wins.
    """
    inverted = {}
    for group_name, members in group_tag_file.items():
        # dict.fromkeys assigns the same group to every member tag.
        inverted.update(dict.fromkeys(members, group_name))
    return inverted


def trans_word_group(word_groups):
    """Parse the raw lines of a word-group file into a ``{group: words}`` dict.

    Two layouts are handled:

    * A single line that contains backslashes: the backslashes are removed
      and the line is evaluated as one complete dict literal.
    * Anything else: every line is treated as the inside of a dict literal
      (``"group": [words]``), wrapped in braces, evaluated, and merged into
      the result. Lines that fail to evaluate, are blank, or do not yield a
      dict are skipped.

    Args:
        word_groups: Sequence of text lines (e.g. from ``readlines()``).

    Returns:
        Dict mapping each group to its words.

    Raises:
        ValueError: If the single backslash-escaped line does not evaluate
            to a dict.
        Exception: Whatever ``eval`` raises for a malformed single
            backslash-escaped line.
    """
    # NOTE(security): eval() executes arbitrary expressions. It is kept to
    # preserve the accepted input syntax, but these files must be trusted.
    # Consider ast.literal_eval if the inputs are pure literals.
    if len(word_groups) == 1 and "\\" in word_groups[0]:
        # Inspect the line itself (not the list) for backslashes, then strip
        # them; surrounding whitespace would make eval raise.
        unescaped = word_groups[0].replace("\\", "").strip()
        parsed = eval(unescaped)
        if not isinstance(parsed, dict):
            raise ValueError(
                "word-group line did not evaluate to a dict: %r" % unescaped
            )
        return parsed

    updated_word_groups = {}
    for line in word_groups:
        try:
            entry = eval("{" + line + "}")
        except Exception:
            continue
        # A blank line evaluates to {} and "a, b" evaluates to a set; neither
        # contributes a group, so only merge real dict entries.
        if isinstance(entry, dict):
            updated_word_groups.update(entry)
    return updated_word_groups


if __name__ == "__main__":

    # FUNCTION: Replace every tag of the targeted collection with its word
    #           group, sum the counts of tags that fall into the same group,
    #           and write the result for the next stage (charts and
    #           description generation).
    # INPUT 1: "0_cate_att_tag_counts.json" of the specific collection
    #       2: word-grouping dictionaries per category and aspect
    #          ("<category>_<aspect>.dict"), generated by
    #          "1_2_tag_word_grouping.py"
    # OUTPUT: "1_cate_att_tag_grouped_counts.json" of the specific collection,
    #         with tag words replaced by their word group. Tags without a
    #         group are accumulated under "na".

    grouped_tag_path = os.path.join(current_dir, "1_outfit_tagging_res_garment_level_refined/grouped_tags/")
    # The collection is selected through constants.collections_title
    # (e.g. "2023" or "2019_2023_ss_all").
    collection_data_path = grouped_tag_path + constants.collections_title + "/"
    group_data_path = os.path.join(current_dir, "../data/grouped_tags/2019_2023_ss_all/")

    with open(collection_data_path + "0_cate_att_tag_counts.json") as counts_fh:
        cate_att_tags_all = json.load(counts_fh)

    # Resume from a previous run if an output file already exists.
    cate_att_tags_new_file = collection_data_path + "1_cate_att_tag_grouped_counts.json"
    if os.path.exists(cate_att_tags_new_file):
        with open(cate_att_tags_new_file) as previous_fh:
            cate_att_tags_new = json.load(previous_fh)
    else:
        cate_att_tags_new = {}

    # File listing the aspects that failed, which may require checking the
    # files produced by the previous stage.
    faild_aspect_file = grouped_tag_path + "faild_aspects.txt"
    target_categories = list(cate_att_tags_all.keys())

    for target_category in target_categories:
        done_aspects = cate_att_tags_new.setdefault(target_category, {})
        target_aspects = find_category_aspects(cate_att_tags_all, target_category)
        for target_aspect in target_aspects:
            if target_aspect in done_aspects:
                print("done %s %s" % (target_category, target_aspect))
                continue
            print("summary %s %s" % (target_category, target_aspect))

            # Load the grouping dictionary: JSON first, raw lines as fallback.
            dict_path = group_data_path + "%s_%s.dict" % (target_category, target_aspect)
            group_tag_file = None
            try:
                with open(dict_path, "r") as dict_fh:
                    raw_lines = dict_fh.readlines()
                try:
                    group_tag_file = json.loads("".join(raw_lines))
                except ValueError:
                    group_tag_file = raw_lines
                if isinstance(group_tag_file, list):
                    group_tag_file = trans_word_group(group_tag_file)
            except Exception:
                group_tag_file = None

            if group_tag_file is None:
                # Record the failure and skip. Nothing is stored for this
                # aspect, so it is not treated as "done" and is retried on
                # the next run instead of reusing a stale or missing grouping.
                with open(faild_aspect_file, "a") as failed_fh:
                    failed_fh.write("%s %s %s \n" % (constants.collections_title, target_category, target_aspect))
                continue

            tag_group_map = get_tag_group_map(group_tag_file)
            cate_att_tags = cate_att_tags_all[target_category][target_aspect]

            # Aggregate counts per word group; ungrouped tags go to "na".
            grouped_counts = {}
            for att, att_cnt in cate_att_tags.items():
                new_att = tag_group_map.get(att, "na")
                grouped_counts[new_att] = grouped_counts.get(new_att, 0) + att_cnt
            done_aspects[target_aspect] = grouped_counts

            # Checkpoint after every aspect so an interrupted run can resume.
            with open(cate_att_tags_new_file, "w") as out_fh:
                json.dump(cate_att_tags_new, out_fh)
